{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  图像分类模型通道剪裁-敏感度分析\n",
    "\n",
    "该教程以图像分类模型MobileNetV1为例，说明如何快速使用[PaddleSlim的敏感度分析接口](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#sensitivity)。\n",
    "该示例包含以下步骤：\n",
    "\n",
    "1. 导入依赖\n",
    "2. 构建模型\n",
    "3. 定义输入数据\n",
    "4. 定义模型评估方法\n",
    "5. 训练模型\n",
    "6. 获取待分析卷积参数名称\n",
    "7. 分析敏感度\n",
    "8. 剪裁模型\n",
    "\n",
    "以下章节依次次介绍每个步骤的内容。\n",
    "\n",
    "## 1. 导入依赖\n",
    "\n",
    "PaddleSlim依赖Paddle1.7版本，请确认已正确安装Paddle，然后按以下方式导入Paddle和PaddleSlim:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "import paddleslim as slim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 构建网络\n",
    "\n",
    "该章节构造一个用于对MNIST数据进行分类的分类模型，选用`MobileNetV1`，并将输入大小设置为`[1, 28, 28]`，输出类别数为10。\n",
    "为了方便展示示例，我们在`paddleslim.models`下预定义了用于构建分类模型的方法，执行以下代码构建分类模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "exe, train_program, val_program, inputs, outputs = slim.models.image_classification(\"MobileNet\", [1, 28, 28], 10, use_gpu=True)\n",
    "place = fluid.CUDAPlace(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 定义输入数据\n",
    "\n",
    "为了快速执行该示例，我们选取简单的MNIST数据，Paddle框架的`paddle.dataset.mnist`包定义了MNIST数据的下载和读取。\n",
    "代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle.dataset.mnist as reader\n",
    "train_reader = paddle.fluid.io.batch(\n",
    "        reader.train(), batch_size=128, drop_last=True)\n",
    "test_reader = paddle.fluid.io.batch(\n",
    "        reader.test(), batch_size=128, drop_last=True)\n",
    "data_feeder = fluid.DataFeeder(inputs, place)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 定义模型评估方法\n",
    "\n",
    "在计算敏感度时，需要裁剪单个卷积层后的模型在测试数据上的效果，我们定义以下方法实现该功能："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def test(program):\n",
    "    acc_top1_ns = []\n",
    "    acc_top5_ns = []\n",
    "    for data in test_reader():\n",
    "        acc_top1_n, acc_top5_n, _ = exe.run(\n",
    "            program,\n",
    "            feed=data_feeder.feed(data),\n",
    "            fetch_list=outputs)\n",
    "        acc_top1_ns.append(np.mean(acc_top1_n))\n",
    "        acc_top5_ns.append(np.mean(acc_top5_n))\n",
    "    print(\"Final eva - acc_top1: {}; acc_top5: {}\".format(\n",
    "        np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))\n",
    "    return np.mean(np.array(acc_top1_ns))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 训练模型\n",
    "\n",
    "只有训练好的模型才能做敏感度分析，因为该示例任务相对简单，我这里用训练一个`epoch`产出的模型做敏感度分析。对于其它训练比较耗时的模型，您可以加载训练好的模型权重。\n",
    "\n",
    "以下为模型训练代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.984375 1.0 0.04038039\n"
     ]
    }
   ],
   "source": [
    "for data in train_reader():\n",
    "    acc1, acc5, loss = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs)\n",
    "print(np.mean(acc1), np.mean(acc5), np.mean(loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用上节定义的模型评估方法，评估当前模型在测试集上的精度："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9574319124221802; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9574319"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(val_program)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 获取待分析卷积参数\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['conv2_1_sep_weights', 'conv2_2_sep_weights', 'conv3_1_sep_weights', 'conv3_2_sep_weights', 'conv4_1_sep_weights', 'conv4_2_sep_weights', 'conv5_1_sep_weights', 'conv5_2_sep_weights', 'conv5_3_sep_weights', 'conv5_4_sep_weights', 'conv5_5_sep_weights', 'conv5_6_sep_weights', 'conv6_sep_weights']\n"
     ]
    }
   ],
   "source": [
    "params = []\n",
    "for param in train_program.global_block().all_parameters():\n",
    "    if \"_sep_weights\" in param.name:\n",
    "        params.append(param.name)\n",
    "print(params)\n",
    "params = params[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. 分析敏感度\n",
    "\n",
    "### 7.1 简单计算敏感度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "调用[sensitivity接口](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#sensitivity)对训练好的模型进行敏感度分析。\n",
    "\n",
    "在计算过程中，敏感度信息会不断追加保存到选项`sensitivities_file`指定的文件中，该文件中已有的敏感度信息不会被重复计算。\n",
    "\n",
    "先用以下命令删除当前路径下可能已有的`sensitivities_0.data`文件:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf sensitivities_0.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "除了指定待分析的卷积层参数，我们还可以指定敏感度分析的粒度和范围，即单个卷积层参数分别被剪裁掉的比例。\n",
    "\n",
    "如果待分析的模型比较敏感，剪掉单个卷积层的40%的通道，模型在测试集上的精度损失就达90%，那么`pruned_ratios`最大设置到0.4即可，比如：\n",
    "`[0.1, 0.2, 0.3, 0.4]`\n",
    "\n",
    "为了得到更精确的敏感度信息，我可以适当调小`pruned_ratios`的粒度，比如：`[0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]`\n",
    "\n",
    "`pruned_ratios`的粒度越小，计算敏感度的速度越慢。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:33,091-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9574319124221802; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:35,971-INFO: pruned param: conv2_2_sep_weights; 0.1; loss=0.025107262656092644\n",
      "2020-02-04 15:29:35,975-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9333934187889099; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:38,797-INFO: pruned param: conv2_2_sep_weights; 0.2; loss=0.04069465771317482\n",
      "2020-02-04 15:29:38,801-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9184695482254028; acc_top5: 0.9983974099159241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:42,056-INFO: pruned param: conv2_1_sep_weights; 0.1; loss=0.035987019538879395\n",
      "2020-02-04 15:29:42,059-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9229767918586731; acc_top5: 0.9989984035491943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:45,121-INFO: pruned param: conv2_1_sep_weights; 0.2; loss=0.031697917729616165\n",
      "2020-02-04 15:29:45,124-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9270833134651184; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:48,070-INFO: pruned param: conv3_1_sep_weights; 0.1; loss=-0.00010458791075507179\n",
      "2020-02-04 15:29:48,073-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9575320482254028; acc_top5: 0.9992988705635071\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:51,172-INFO: pruned param: conv3_1_sep_weights; 0.2; loss=0.004707638639956713\n",
      "2020-02-04 15:29:51,174-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9529246687889099; acc_top5: 0.9993990659713745\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:54,379-INFO: pruned param: conv4_1_sep_weights; 0.1; loss=0.0015692544402554631\n",
      "2020-02-04 15:29:54,382-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9559294581413269; acc_top5: 0.9993990659713745\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:29:57,316-INFO: pruned param: conv4_1_sep_weights; 0.2; loss=0.001987668452784419\n",
      "2020-02-04 15:29:57,319-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9555288553237915; acc_top5: 0.9989984035491943\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:00,300-INFO: pruned param: conv3_2_sep_weights; 0.1; loss=-0.005021402612328529\n",
      "2020-02-04 15:30:00,306-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9622395634651184; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:03,400-INFO: pruned param: conv3_2_sep_weights; 0.2; loss=0.0008369522984139621\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9566305875778198; acc_top5: 0.9991987347602844\n",
      "{'conv2_2_sep_weights': {0.1: 0.025107263, 0.2: 0.040694658}, 'conv2_1_sep_weights': {0.1: 0.03598702, 0.2: 0.031697918}, 'conv3_1_sep_weights': {0.1: -0.00010458791, 0.2: 0.0047076386}, 'conv4_1_sep_weights': {0.1: 0.0015692544, 0.2: 0.0019876685}, 'conv3_2_sep_weights': {0.1: -0.0050214026, 0.2: 0.0008369523}}\n"
     ]
    }
   ],
   "source": [
    "sens_0 = slim.prune.sensitivity(\n",
    "        val_program,\n",
    "        place,\n",
    "        params,\n",
    "        test,\n",
    "        sensitivities_file=\"sensitivities_0.data\",\n",
    "        pruned_ratios=[0.1, 0.2])\n",
    "print(sens_0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 扩展敏感度信息\n",
    "\n",
    "第7.1节计算敏感度用的是`pruned_ratios=[0.1, 0.2]`, 我们可以在此基础上将其扩展到`[0.1, 0.2, 0.3]`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:16,173-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9574319124221802; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:19,087-INFO: pruned param: conv2_2_sep_weights; 0.3; loss=0.2279527187347412\n",
      "2020-02-04 15:30:19,091-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.739182710647583; acc_top5: 0.9918870329856873\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:22,079-INFO: pruned param: conv2_1_sep_weights; 0.3; loss=0.08871221542358398\n",
      "2020-02-04 15:30:22,082-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.8724960088729858; acc_top5: 0.9975961446762085\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:24,974-INFO: pruned param: conv3_1_sep_weights; 0.3; loss=0.005439940840005875\n",
      "2020-02-04 15:30:24,976-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.952223539352417; acc_top5: 0.999098539352417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:28,071-INFO: pruned param: conv4_1_sep_weights; 0.3; loss=0.03535936772823334\n",
      "2020-02-04 15:30:28,073-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9235777258872986; acc_top5: 0.9978966116905212\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-04 15:30:31,068-INFO: pruned param: conv3_2_sep_weights; 0.3; loss=0.008055261336266994\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9497195482254028; acc_top5: 0.9986979365348816\n",
      "{'conv2_2_sep_weights': {0.1: 0.025107263, 0.2: 0.040694658, 0.3: 0.22795272}, 'conv2_1_sep_weights': {0.1: 0.03598702, 0.2: 0.031697918, 0.3: 0.088712215}, 'conv3_1_sep_weights': {0.1: -0.00010458791, 0.2: 0.0047076386, 0.3: 0.005439941}, 'conv4_1_sep_weights': {0.1: 0.0015692544, 0.2: 0.0019876685, 0.3: 0.035359368}, 'conv3_2_sep_weights': {0.1: -0.0050214026, 0.2: 0.0008369523, 0.3: 0.008055261}}\n"
     ]
    }
   ],
   "source": [
    "sens_0 = slim.prune.sensitivity(\n",
    "        val_program,\n",
    "        place,\n",
    "        params,\n",
    "        test,\n",
    "        sensitivities_file=\"sensitivities_0.data\",\n",
    "        pruned_ratios=[0.3])\n",
    "print(sens_0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.3 多进程加速计算敏感度信息\n",
    "\n",
    "敏感度分析所用时间取决于待分析的卷积层数量和模型评估的速度，我们可以通过多进程的方式加速敏感度计算。\n",
    "\n",
    "在不同的进程设置不同`pruned_ratios`, 然后将结果合并。\n",
    "\n",
    "#### 7.3.1 多进程计算敏感度\n",
    "\n",
    "在以上章节，我们计算了`pruned_ratios=[0.1, 0.2, 0.3]`的敏感度，并将其保存到了文件`sensitivities_0.data`中。\n",
    "\n",
    "在另一个进程中，我们可以设置`pruned_ratios=[0.4]`，并将结果保存在文件`sensitivities_1.data`中。代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'conv2_2_sep_weights': {0.4: 0.06348718}, 'conv2_1_sep_weights': {0.4: 0.15917951}, 'conv4_1_sep_weights': {0.4: 0.16246155}, 'conv3_1_sep_weights': {0.4: 0.034871764}, 'conv3_2_sep_weights': {0.4: 0.115384646}}\n"
     ]
    }
   ],
   "source": [
    "sens_1 = slim.prune.sensitivity(\n",
    "        val_program,\n",
    "        place,\n",
    "        params,\n",
    "        test,\n",
    "        sensitivities_file=\"sensitivities_1.data\",\n",
    "        pruned_ratios=[0.4])\n",
    "print(sens_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 7.3.2 加载多个进程产出的敏感度文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'conv2_2_sep_weights': {0.1: 0.025107263, 0.2: 0.040694658, 0.3: 0.22795272}, 'conv2_1_sep_weights': {0.1: 0.03598702, 0.2: 0.031697918, 0.3: 0.088712215}, 'conv3_1_sep_weights': {0.1: -0.00010458791, 0.2: 0.0047076386, 0.3: 0.005439941}, 'conv4_1_sep_weights': {0.1: 0.0015692544, 0.2: 0.0019876685, 0.3: 0.035359368}, 'conv3_2_sep_weights': {0.1: -0.0050214026, 0.2: 0.0008369523, 0.3: 0.008055261}}\n",
      "{'conv2_2_sep_weights': {0.4: 0.06348718}, 'conv2_1_sep_weights': {0.4: 0.15917951}, 'conv4_1_sep_weights': {0.4: 0.16246155}, 'conv3_1_sep_weights': {0.4: 0.034871764}, 'conv3_2_sep_weights': {0.4: 0.115384646}}\n"
     ]
    }
   ],
   "source": [
    "s_0 = slim.prune.load_sensitivities(\"sensitivities_0.data\")\n",
    "s_1 = slim.prune.load_sensitivities(\"sensitivities_1.data\")\n",
    "print(s_0)\n",
    "print(s_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 7.3.3 合并敏感度信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'conv2_2_sep_weights': {0.1: 0.025107263, 0.2: 0.040694658, 0.3: 0.22795272, 0.4: 0.06348718}, 'conv2_1_sep_weights': {0.1: 0.03598702, 0.2: 0.031697918, 0.3: 0.088712215, 0.4: 0.15917951}, 'conv3_1_sep_weights': {0.1: -0.00010458791, 0.2: 0.0047076386, 0.3: 0.005439941, 0.4: 0.034871764}, 'conv4_1_sep_weights': {0.1: 0.0015692544, 0.2: 0.0019876685, 0.3: 0.035359368, 0.4: 0.16246155}, 'conv3_2_sep_weights': {0.1: -0.0050214026, 0.2: 0.0008369523, 0.3: 0.008055261, 0.4: 0.115384646}}\n"
     ]
    }
   ],
   "source": [
    "s = slim.prune.merge_sensitive([s_0, s_1])\n",
    "print(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. 剪裁模型\n",
    "\n",
    "根据以上章节产出的敏感度信息，对模型进行剪裁。\n",
    "\n",
    "### 8.1 计算剪裁率\n",
    "\n",
    "首先，调用PaddleSlim提供的[get_ratios_by_loss](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#get_ratios_by_loss)方法根据敏感度计算剪裁率，通过调整参数`loss`大小获得合适的一组剪裁率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'conv3_1_sep_weights': 0.3, 'conv4_1_sep_weights': 0.22400936122727166, 'conv3_2_sep_weights': 0.3}\n"
     ]
    }
   ],
   "source": [
    "loss = 0.01\n",
    "ratios = slim.prune.get_ratios_by_loss(s_0, loss)\n",
    "print(ratios)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.2 剪裁训练网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FLOPs before pruning: 10896832.0\n",
      "FLOPs after pruning: 9777980.0\n"
     ]
    }
   ],
   "source": [
    "pruner = slim.prune.Pruner()\n",
    "print(\"FLOPs before pruning: {}\".format(slim.analysis.flops(train_program)))\n",
    "pruned_program, _, _ = pruner.prune(\n",
    "        train_program,\n",
    "        fluid.global_scope(),\n",
    "        params=ratios.keys(),\n",
    "        ratios=ratios.values(),\n",
    "        place=place)\n",
    "print(\"FLOPs after pruning: {}\".format(slim.analysis.flops(pruned_program)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.3 剪裁测试网络\n",
    "\n",
    ">注意：对测试网络进行剪裁时，需要将`only_graph`设置为True，具体原因请参考[Pruner API文档](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/#pruner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FLOPs before pruning: 10896832.0\n",
      "FLOPs after pruning: 9777980.0\n"
     ]
    }
   ],
   "source": [
    "pruner = slim.prune.Pruner()\n",
    "print(\"FLOPs before pruning: {}\".format(slim.analysis.flops(val_program)))\n",
    "pruned_val_program, _, _ = pruner.prune(\n",
    "        val_program,\n",
    "        fluid.global_scope(),\n",
    "        params=ratios.keys(),\n",
    "        ratios=ratios.values(),\n",
    "        place=place,\n",
    "        only_graph=True)\n",
    "print(\"FLOPs after pruning: {}\".format(slim.analysis.flops(pruned_val_program)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试一下剪裁后的模型在测试集上的精度："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9721554517745972; acc_top5: 0.9995993375778198\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.97215545"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(pruned_val_program)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.4 训练剪裁后的模型\n",
    "\n",
    "对剪裁后的模型在训练集上训练一个`epoch`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.984375 1.0 0.04675974\n"
     ]
    }
   ],
   "source": [
    "for data in train_reader():\n",
    "    acc1, acc5, loss = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs)\n",
    "print(np.mean(acc1), np.mean(acc5), np.mean(loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试训练后模型的精度："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final eva - acc_top1: 0.9721554517745972; acc_top5: 0.9995993375778198\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.97215545"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(pruned_val_program)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
